{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concise Implementation of Softmax Regression\n",
    "\n",
    "Just as PyTorch made it much easier to implement linear regression, we'll find it similarly (or possibly more)\n",
    "convenient for implementing classification models. Again, we begin with our import ritual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '..')\n",
    "import d2l\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's stick with the Fashion-MNIST dataset and keep the batch size at $256$ as in the last section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Model Parameters\n",
    "\n",
    "As mentioned in `chapter_softmax`, the output layer of softmax regression is a fully connected (`Linear`) layer. Therefore, to implement our model, we just need to add one `Linear` layer with 10 outputs to our `Sequential`. Again, here, the `Sequential` isn't really necessary, but we might as well form the habit since it will be ubiquitous when implementing deep models. Again, we initialize the weights at random with zero mean and standard deviation 0.01."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Reshape()\n",
       "  (1): Linear(in_features=784, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Reshape(torch.nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x.view(-1,784)\n",
    "    \n",
    "net = nn.Sequential(Reshape(), nn.Linear(784, 10))\n",
    "\n",
    "def init_weights(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        torch.nn.init.normal_(m.weight, std=0.01)\n",
    "\n",
    "net.apply(init_weights)"
   ]
  },
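  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the model definition above, the following sketch rebuilds the same architecture and pushes a random batch through it. It swaps the custom `Reshape` for PyTorch's built-in `nn.Flatten`, which is an assumption on our part (it needs a reasonably recent PyTorch) rather than what the code above uses; the names `check_net` and `X` are illustrative.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# nn.Flatten keeps the batch dimension and flattens the rest: (N, 1, 28, 28) -> (N, 784).\n",
    "check_net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))\n",
    "nn.init.normal_(check_net[1].weight, std=0.01)\n",
    "\n",
    "# A random stand-in for a batch of four grayscale 28x28 images.\n",
    "X = torch.randn(4, 1, 28, 28)\n",
    "out = check_net(X)\n",
    "print(out.shape)                       # one row of 10 scores per image\n",
    "print(check_net[1].weight.std())       # should be close to 0.01\n",
    "```\n",
    "\n",
    "Each input image yields one row of 10 unnormalized class scores, which is exactly what the loss function introduced later consumes."
   ]
  },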
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Softmax\n",
    "\n",
    "In the previous example, we calculated our model's output and then ran this\n",
    "output through the cross-entropy loss. Mathematically, that's a perfectly reasonable thing to do. However,\n",
    "computationally, things can get hairy when dealing with exponentiation due to\n",
    "numerical stability issues, a matter we've already discussed a few times\n",
    "(e.g. in `chapter_naive_bayes`) and in the problem set of the previous chapter).\n",
    "Recall that the softmax function calculates $\\hat y_j = \\frac{e^{z_j}}{\\sum_{i=1}^{n} e^{z_i}}$, where $\\hat y_j$\n",
    "is the j-th element of ``yhat`` and $z_j$ is the j-th element of the input\n",
    "``y_linear`` variable, as computed by the softmax.\n",
    "\n",
    "If some of the $z_i$ are very large (i.e. very positive),\n",
    "$e^{z_i}$ might be larger than the largest number\n",
    "we can have for certain types of ``float`` (i.e. overflow).\n",
    "This would make the denominator (and/or numerator) ``inf`` and we get zero,\n",
    "or ``inf``, or ``nan`` for $\\hat y_j$.\n",
    "In any case, we won't get a well-defined return value for ``cross_entropy``. This is the reason we subtract $\\text{max}(z_i)$\n",
    "from all $z_i$ first in ``softmax`` function.\n",
    "You can verify that this shifting in $z_i$\n",
    "will not change the return value of ``softmax``.\n",
    "\n",
    "After the above subtraction/ normalization step,\n",
    "it is possible that $z_j$ is very negative.\n",
    "Thus, $e^{z_j}$ will be very close to zero\n",
    "and might be rounded to zero due to finite precision (i.e underflow),\n",
    "which makes $\\hat y_j$ zero and we get ``-inf`` for $\\text{log}(\\hat y_j)$.\n",
    "A few steps down the road in backpropagation,\n",
    "we start to get horrific not-a-number (``nan``) results printed to screen.\n",
    "\n",
    "Our salvation is that even though we're computing these exponential functions, we ultimately plan to take their log in the cross-entropy functions.\n",
    "It turns out that by combining these two operators\n",
    "``softmax`` and ``cross_entropy`` together,\n",
    "we can escape the numerical stability issues\n",
    "that might otherwise plague us during backpropagation.\n",
    "As shown in the equation below, we avoided calculating $e^{z_j}$\n",
    "but directly used $z_j$ due to $\\log(\\exp(\\cdot))$.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\log{(\\hat y_j)} & = \\log\\left( \\frac{e^{z_j}}{\\sum_{i=1}^{n} e^{z_i}}\\right) \\\\\n",
    "& = \\log{(e^{z_j})}-\\text{log}{\\left( \\sum_{i=1}^{n} e^{z_i} \\right)} \\\\\n",
    "& = z_j -\\log{\\left( \\sum_{i=1}^{n} e^{z_i} \\right)}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "We'll want to keep the conventional softmax function handy\n",
    "in case we ever want to evaluate the probabilities output by our model.\n",
    "But instead of passing softmax probabilities into our new loss function,\n",
    "we'll just pass $\\hat{y}$ and compute the softmax and its log\n",
    "all at once inside the softmax_cross_entropy loss function,\n",
    "which does smart things like the log-sum-exp trick ([see on Wikipedia](https://en.wikipedia.org/wiki/LogSumExp)).\n"
   ]
  },
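  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The overflow and underflow cases described above can be reproduced in a few lines. The following sketch uses a hypothetical three-element vector `z` in single precision; the values are chosen only to trigger the problems and do not come from the model.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "z = torch.tensor([1000.0, 0.0, -1000.0])\n",
    "\n",
    "# Naive softmax: exp(1000) overflows to inf in float32, so inf / inf gives nan.\n",
    "naive = torch.exp(z) / torch.exp(z).sum()\n",
    "\n",
    "# Shifting by the maximum avoids overflow, but exp(-1000) underflows to 0,\n",
    "# so taking the log afterwards still produces -inf.\n",
    "shifted = torch.exp(z - z.max())\n",
    "log_of_softmax = torch.log(shifted / shifted.sum())\n",
    "\n",
    "# Combined form: log softmax(z)_j = (z_j - max) - log(sum_i exp(z_i - max)).\n",
    "stable = (z - z.max()) - torch.log(torch.exp(z - z.max()).sum())\n",
    "\n",
    "print(naive)\n",
    "print(log_of_softmax)\n",
    "print(stable)\n",
    "print(F.log_softmax(z, dim=0))   # PyTorch's built-in log-softmax for comparison\n",
    "```\n",
    "\n",
    "Only the combined form keeps every entry finite, which is why the loss is computed from the raw outputs rather than from softmax probabilities."
   ]
  },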
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = nn.CrossEntropyLoss()"
   ]
  },
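  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `nn.CrossEntropyLoss` expects, here is a small sketch with made-up logits and labels (both hypothetical). It assumes the default `reduction='mean'` and checks the result against a manual computation built from `F.log_softmax`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.randn(3, 10)            # raw scores for 3 examples and 10 classes\n",
    "labels = torch.tensor([2, 0, 9])       # integer class indices, not one-hot vectors\n",
    "\n",
    "# nn.CrossEntropyLoss takes logits directly and averages over the batch by default.\n",
    "builtin = nn.CrossEntropyLoss()(logits, labels)\n",
    "\n",
    "# Manual equivalent: pick the log-probability of the true class for each example.\n",
    "log_probs = F.log_softmax(logits, dim=1)\n",
    "manual = -log_probs[torch.arange(3), labels].mean()\n",
    "\n",
    "print(builtin.item(), manual.item())\n",
    "```\n",
    "\n",
    "Note that the model's outputs go into the loss without an explicit softmax; applying softmax first would normalize the scores twice."
   ]
  },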
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimization Algorithm\n",
    "\n",
    "We use the mini-batch random gradient descent\n",
    "with a learning rate of $0.1$ as the optimization algorithm.\n",
    "Note that this is the same choice as for linear regression\n",
    "and it illustrates the general applicability of the optimizers."
   ]
  },
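  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook the learning rate is simply handed to `d2l.train_ch3`, so no optimizer object is constructed explicitly. As a rough sketch of what a single update looks like with PyTorch's `torch.optim.SGD`, assuming a stand-in linear model and random data (the names `model`, `X`, and `y` are illustrative):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "model = nn.Linear(784, 10)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
    "\n",
    "X = torch.randn(256, 784)              # one random minibatch of 256 examples\n",
    "y = torch.randint(0, 10, (256,))       # random class labels\n",
    "\n",
    "w_before = model.weight.detach().clone()\n",
    "optimizer.zero_grad()                  # clear gradients from any previous step\n",
    "loss_value = nn.CrossEntropyLoss()(model(X), y)\n",
    "loss_value.backward()                  # gradients of the minibatch loss\n",
    "optimizer.step()                       # in-place parameter update\n",
    "\n",
    "# Plain SGD (no momentum, no weight decay) moves each weight by -lr * gradient.\n",
    "expected = w_before - 0.1 * model.weight.grad\n",
    "print(torch.allclose(model.weight.detach(), expected, atol=1e-6))\n",
    "```\n",
    "\n",
    "The same three calls (`zero_grad`, `backward`, `step`) apply unchanged to linear regression, which is the sense in which the optimizer is generally applicable."
   ]
  },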
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Next, we use the training functions defined in the last section to train a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.0031, train acc 0.750, test acc 0.786\n",
      "epoch 2, loss 0.0022, train acc 0.813, test acc 0.812\n",
      "epoch 3, loss 0.0021, train acc 0.826, test acc 0.815\n",
      "epoch 4, loss 0.0020, train acc 0.832, test acc 0.823\n",
      "epoch 5, loss 0.0019, train acc 0.837, test acc 0.809\n",
      "epoch 6, loss 0.0019, train acc 0.840, test acc 0.814\n",
      "epoch 7, loss 0.0018, train acc 0.842, test acc 0.820\n",
      "epoch 8, loss 0.0018, train acc 0.845, test acc 0.830\n",
      "epoch 9, loss 0.0018, train acc 0.847, test acc 0.834\n",
      "epoch 10, loss 0.0018, train acc 0.850, test acc 0.832\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 10\n",
    "lr = 0.1\n",
    "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, lr)"
   ]
  },
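  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The internals of `d2l.train_ch3` are not shown in this section. As a rough, self-contained sketch of the kind of loop such a function typically runs (an assumption, not the actual `d2l` implementation), the following trains a linear softmax classifier on synthetic data; `X`, `y`, `model`, `accuracy`, and `history` are hypothetical names.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Synthetic stand-in for a dataset: labels come from a random linear rule.\n",
    "X = torch.randn(256, 784)\n",
    "y = (X @ torch.randn(784, 10)).argmax(dim=1)\n",
    "\n",
    "model = nn.Linear(784, 10)\n",
    "nn.init.normal_(model.weight, std=0.01)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
    "\n",
    "def accuracy(logits, labels):\n",
    "    # Fraction of examples whose highest-scoring class matches the label.\n",
    "    return (logits.argmax(dim=1) == labels).float().mean().item()\n",
    "\n",
    "history = []\n",
    "for epoch in range(20):\n",
    "    optimizer.zero_grad()\n",
    "    logits = model(X)\n",
    "    l = loss_fn(logits, y)\n",
    "    l.backward()\n",
    "    optimizer.step()\n",
    "    history.append(l.item())\n",
    "    print(f'epoch {epoch + 1}, loss {l.item():.4f}, train acc {accuracy(logits, y):.3f}')\n",
    "```\n",
    "\n",
    "For brevity the sketch treats the whole synthetic set as a single batch and reports training accuracy only; a real run would iterate over minibatches from `train_iter` and evaluate on `test_iter` as well."
   ]
  },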
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just as before, this algorithm converges to a solution\n",
    "that achieves an accuracy of 83.3%,\n",
    "albeit this time with a lot fewer lines of code than before.\n",
    "Note that in many cases, PyTorch takes specific precautions\n",
    "in addition to the most well-known tricks for ensuring numerical stability.\n",
    "This saves us from many common pitfalls that might befall us\n",
    "if we were to code all of our models from scratch.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Try adjusting the hyper-parameters, such as batch size, epoch, and learning rate, to see what the results are.\n",
    "1. Why might the test accuracy decrease again after a while? How could we fix this?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
